{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chemprop MPNN models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.models.model import MPNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Composition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Chemprop `MPNN` model is made up of several submodules including a [message passing](./message_passing.ipynb) layer, an [aggregation](./aggregation.ipynb) layer, an optional batch normalization layer, and a [predictor](./predictor.ipynb) feed forward network layer. `MPNN` defines the training and predicting logic used by `lightning` when using a Chemprop model in their framework. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPNN(\n",
       "  (message_passing): BondMessagePassing(\n",
       "    (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "    (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "    (tau): ReLU()\n",
       "    (V_d_transform): Identity()\n",
       "    (graph_transform): Identity()\n",
       "  )\n",
       "  (agg): NormAggregation()\n",
       "  (bn): Identity()\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): Identity()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0-1): 2 x MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN\n",
    "\n",
    "mp = BondMessagePassing()\n",
    "agg = NormAggregation()\n",
    "ffn = RegressionFFN()\n",
    "\n",
    "basic_model = MPNN(mp, agg, ffn)\n",
    "basic_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batch normalization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Batch normalization can improve training by keeping the inputs to the FFN small and centered around zero. It is off by default, but can be turned on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPNN(\n",
       "  (message_passing): BondMessagePassing(\n",
       "    (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "    (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "    (tau): ReLU()\n",
       "    (V_d_transform): Identity()\n",
       "    (graph_transform): Identity()\n",
       "  )\n",
       "  (agg): NormAggregation()\n",
       "  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): Identity()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0-1): 2 x MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MPNN(mp, agg, ffn, batch_norm=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MPNN` also configures the optimizer used by lightning during training. The `torch.optim.Adam` optimizer is used with a Noam learning rate scheduler (defined in `chemprop.scheduler.NoamLR`). The following parameters are customizable:\n",
    "\n",
    " - number of warmup epochs, defaults to 2\n",
    " - the initial learning rate, defaults to $10^{-4}$\n",
    " - the max learning rate, defaults to $10^{-3}$\n",
    " - the final learning rate, defaults to $10^{-4}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MPNN(mp, agg, ffn, warmup_epochs=5, init_lr=1e-3, max_lr=1e-2, final_lr=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "During the validation and testing loops, lightning will use the metrics stored in `MPNN` to evaluate the current model's performance. The `MPNN` has a default metric defined by the type of predictor used. Other [metrics](../metrics.ipynb) can be given to `MPNN` to use instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.nn import metrics\n",
    "\n",
    "metrics_list = [metrics.RMSE(), metrics.MAE()]\n",
    "model = MPNN(mp, agg, ffn, metrics=metrics_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fingerprinting and encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MPNN` has two helper functions to get the hidden representations at different parts of the model. The fingerprint is the learned representation of the message passing layer after aggregation and batch normalization. The encoding is the hidden representation after a number of layers of the predictor. See the predictor notebook for more details. Note that the 0th encoding is equivalent to the fingerprint."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example batch for the model. See the [data notebooks](../data/dataloaders.ipynb) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from chemprop.data import MoleculeDatapoint, MoleculeDataset\n",
    "from chemprop.data import build_dataloader\n",
    "\n",
    "smis = [\"C\" * i for i in range(1, 4)]\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])\n",
    "dataloader = build_dataloader(dataset)\n",
    "batch = next(iter(dataloader))\n",
    "bmg, V_d, X_d, *_ = batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0333],\n",
       "        [0.0331],\n",
       "        [0.0332]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "basic_model(bmg, V_d, X_d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 300])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "basic_model.fingerprint(bmg, V_d, X_d).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 300])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "basic_model.encoding(bmg, V_d, X_d, i=1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(basic_model.fingerprint(bmg, V_d, X_d) == basic_model.encoding(bmg, V_d, X_d, i=0)).all()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
